"""
Created on February 4, 2017

@author: optas

Edited by itailang
"""
from __future__ import print_function

from builtins import str
from builtins import range
import tensorflow as tf
import numpy as np
import warnings

from tflearn.layers.core import fully_connected, dropout
from tflearn.layers.conv import conv_1d, avg_pool_1d
from tflearn.layers.normalization import batch_normalization
from tflearn.layers.core import fully_connected, dropout

from .tf_utils import expand_scope_by_name, replicate_parameter_for_all_layers


def encoder_with_convs_and_symmetry(
    in_signal,
    n_filters=[64, 128, 256, 1024],
    filter_sizes=[1],
    strides=[1],
    b_norm=True,
    b_norm_decay=0.9,
    non_linearity=tf.nn.relu,
    regularizer=None,
    weight_decay=0.001,
    symmetry=tf.reduce_max,
    dropout_prob=None,
    pool=avg_pool_1d,
    pool_sizes=None,
    scope=None,
    reuse=False,
    padding="same",
    verbose=False,
    closing=None,
    conv_op=conv_1d,
    return_layer_before_symmetry=False,
):
    """An Encoder (recognition network), which maps inputs onto a latent space.

    Stacks ``len(n_filters)`` convolution layers. Each layer is, in order:
    ``conv_op`` -> optional batch normalization -> optional non-linearity ->
    optional pooling -> optional dropout. After the last layer an optional
    ``symmetry`` reduction over axis 1 and an optional ``closing`` callable
    are applied.

    Args:
        in_signal: Input tensor fed to the first convolution.
            NOTE(review): presumably (batch, n_points, channels) -- confirm.
        n_filters: Number of filters of each convolution; its length defines
            the number of layers (at least 2).
        filter_sizes: Per-layer filter sizes, passed through
            ``replicate_parameter_for_all_layers``.
        strides: Per-layer strides, passed through
            ``replicate_parameter_for_all_layers``.
        b_norm: If true, apply batch normalization after every convolution.
        b_norm_decay: ``decay`` forwarded to ``batch_normalization``.
        non_linearity: Callable applied to each layer's output, or None.
        regularizer: Forwarded to ``conv_op``.
        weight_decay: Forwarded to ``conv_op``.
        symmetry: Callable invoked as ``symmetry(layer, axis=1)`` on the
            output of the last layer, or None to skip.
        dropout_prob: Per-layer drop probabilities (or None), passed through
            ``replicate_parameter_for_all_layers``. Layer ``i`` gets
            ``dropout(layer, 1.0 - dropout_prob[i])`` when the value is > 0.
        pool: Pooling callable invoked as ``pool(layer, kernel_size=...)``.
        pool_sizes: Sequence indexed per layer (it is *not* replicated);
            a None entry skips pooling for that layer. None disables pooling.
        scope: Base scope, combined with each layer name through
            ``expand_scope_by_name``.
        reuse: Forwarded to the convolution and batch-normalization layers.
        padding: Forwarded to ``conv_op``.
        verbose: If true, print diagnostics while building the graph.
        closing: Optional callable applied to the final tensor.
        conv_op: Convolution constructor used for every layer.
        return_layer_before_symmetry: If true, also return the tensor
            produced by the last layer before ``symmetry`` was applied.

    Returns:
        The encoded tensor, or the tuple ``(encoded, layer_before_symmetry)``
        when ``return_layer_before_symmetry`` is true.

    Raises:
        ValueError: If fewer than two layers are requested.
    """

    if verbose:
        print("Building Encoder")

    n_layers = len(n_filters)
    filter_sizes = replicate_parameter_for_all_layers(filter_sizes, n_layers)
    strides = replicate_parameter_for_all_layers(strides, n_layers)
    dropout_prob = replicate_parameter_for_all_layers(dropout_prob, n_layers)

    if n_layers < 2:
        raise ValueError("More than 1 layers are expected.")

    # At least two layers are guaranteed from here on, so the input can be
    # bound once before the loop.
    layer = in_signal
    for i in range(n_layers):
        name = "encoder_conv_layer_" + str(i)
        scope_i = expand_scope_by_name(scope, name)
        layer = conv_op(
            layer,
            nb_filter=n_filters[i],
            filter_size=filter_sizes[i],
            strides=strides[i],
            regularizer=regularizer,
            weight_decay=weight_decay,
            name=name,
            reuse=reuse,
            scope=scope_i,
            padding=padding,
        )

        if verbose:
            print(
                (
                    name,
                    "conv params = ",
                    np.prod(layer.W.get_shape().as_list())
                    + np.prod(layer.b.get_shape().as_list()),
                )
            )

        if b_norm:
            name += "_bnorm"
            scope_i = expand_scope_by_name(scope, name)
            layer = batch_normalization(
                layer, decay=b_norm_decay, name=name, reuse=reuse, scope=scope_i
            )
            if verbose:
                print(
                    (
                        "bnorm params = ",
                        np.prod(layer.beta.get_shape().as_list())
                        + np.prod(layer.gamma.get_shape().as_list()),
                    )
                )

        if non_linearity is not None:
            layer = non_linearity(layer)

        # Pooling needs both a pooling op and a size for this specific layer.
        if (
            pool is not None
            and pool_sizes is not None
            and pool_sizes[i] is not None
        ):
            layer = pool(layer, kernel_size=pool_sizes[i])

        if dropout_prob is not None and dropout_prob[i] > 0:
            # dropout_prob holds drop probabilities; dropout receives the
            # complement.
            layer = dropout(layer, 1.0 - dropout_prob[i])

        if verbose:
            print(layer)
            print(("output size:", np.prod(layer.get_shape().as_list()[1:]), "\n"))

    layer_before_symmetry = layer
    if symmetry is not None:
        layer = symmetry(layer, axis=1)
        if verbose:
            print(layer)

    if closing is not None:
        layer = closing(layer)
        # Diagnostics are only emitted on request, like every other print in
        # this function.
        if verbose:
            print(layer)

    if return_layer_before_symmetry:
        return layer, layer_before_symmetry
    return layer


def decoder_with_fc_only(
    latent_signal,
    layer_sizes=[],
    b_norm=True,
    b_norm_decay=0.9,
    non_linearity=tf.nn.relu,
    regularizer=None,
    weight_decay=0.001,
    reuse=False,
    scope=None,
    dropout_prob=None,
    b_norm_finish=False,
    b_norm_decay_finish=0.9,
    verbose=False,
):
    """A decoding network which maps points from the latent space back onto the data space.

    Builds ``len(layer_sizes)`` fully connected layers. Every layer except
    the last is followed by optional batch normalization, an optional
    non-linearity and optional dropout. The last layer is linear and is only
    followed by batch normalization when ``b_norm_finish`` is set.

    Args:
        latent_signal: Tensor fed to the first fully connected layer.
        layer_sizes: Number of units of each layer (at least two entries).
        b_norm: If true, batch-normalize every layer except the last.
        b_norm_decay: ``decay`` used by those batch normalizations.
        non_linearity: Callable applied to every layer except the last, or
            None to skip.
        regularizer: Forwarded to ``fully_connected``.
        weight_decay: Forwarded to ``fully_connected``.
        reuse: Forwarded to the fully connected and batch-norm layers.
        scope: Base scope, combined with each layer name through
            ``expand_scope_by_name``.
        dropout_prob: Per-layer drop probabilities (or None), passed through
            ``replicate_parameter_for_all_layers``; not applied to the last
            layer.
        b_norm_finish: If true, batch-normalize the output of the last layer.
        b_norm_decay_finish: ``decay`` used by that final batch normalization.
        verbose: If true, print diagnostics while building the graph.

    Returns:
        The output tensor of the last layer.

    Raises:
        ValueError: If fewer than two layer sizes are given.
    """
    if verbose:
        print("Building Decoder")

    n_layers = len(layer_sizes)
    dropout_prob = replicate_parameter_for_all_layers(dropout_prob, n_layers)

    if n_layers < 2:
        raise ValueError("For an FC decoder with single a layer use simpler code.")

    last_idx = n_layers - 1
    net = latent_signal

    # Hidden layers: FC -> [batch norm] -> [non-linearity] -> [dropout].
    for idx in range(last_idx):
        fc_name = "decoder_fc_" + str(idx)
        fc_scope = expand_scope_by_name(scope, fc_name)
        net = fully_connected(
            net,
            layer_sizes[idx],
            activation="linear",
            weights_init="xavier",
            name=fc_name,
            regularizer=regularizer,
            weight_decay=weight_decay,
            reuse=reuse,
            scope=fc_scope,
        )

        if verbose:
            n_weights = np.prod(net.W.get_shape().as_list())
            n_biases = np.prod(net.b.get_shape().as_list())
            print((fc_name, "FC params = ", n_weights + n_biases))

        if b_norm:
            bn_name = fc_name + "_bnorm"
            bn_scope = expand_scope_by_name(scope, bn_name)
            net = batch_normalization(
                net, decay=b_norm_decay, name=bn_name, reuse=reuse, scope=bn_scope
            )
            if verbose:
                n_beta = np.prod(net.beta.get_shape().as_list())
                n_gamma = np.prod(net.gamma.get_shape().as_list())
                print(("bnorm params = ", n_beta + n_gamma))

        if non_linearity is not None:
            net = non_linearity(net)

        if dropout_prob is not None and dropout_prob[idx] > 0:
            keep_prob = 1.0 - dropout_prob[idx]
            net = dropout(net, keep_prob)

        if verbose:
            print(net)
            out_size = np.prod(net.get_shape().as_list()[1:])
            print(("output size:", out_size, "\n"))

    # Output layer: stays linear, with neither non-linearity nor dropout.
    out_name = "decoder_fc_" + str(last_idx)
    out_scope = expand_scope_by_name(scope, out_name)
    net = fully_connected(
        net,
        layer_sizes[last_idx],
        activation="linear",
        weights_init="xavier",
        name=out_name,
        regularizer=regularizer,
        weight_decay=weight_decay,
        reuse=reuse,
        scope=out_scope,
    )
    if verbose:
        n_weights = np.prod(net.W.get_shape().as_list())
        n_biases = np.prod(net.b.get_shape().as_list())
        print((out_name, "FC params = ", n_weights + n_biases))

    if b_norm_finish:
        out_bn_name = out_name + "_bnorm"
        out_bn_scope = expand_scope_by_name(scope, out_bn_name)
        net = batch_normalization(
            net,
            decay=b_norm_decay_finish,
            name=out_bn_name,
            reuse=reuse,
            scope=out_bn_scope,
        )
        if verbose:
            n_beta = np.prod(net.beta.get_shape().as_list())
            n_gamma = np.prod(net.gamma.get_shape().as_list())
            print(("bnorm params = ", n_beta + n_gamma))

    if verbose:
        print(net)
        out_size = np.prod(net.get_shape().as_list()[1:])
        print(("output size:", out_size, "\n"))

    return net


def decoder_with_convs_only(
    in_signal,
    n_filters,
    filter_sizes,
    strides,
    padding="same",
    b_norm=True,
    non_linearity=tf.nn.relu,
    conv_op=conv_1d,
    regularizer=None,
    weight_decay=0.001,
    dropout_prob=None,
    upsample_sizes=None,
    b_norm_finish=False,
    scope=None,
    reuse=False,
    verbose=False,
):
    """A decoding network built only from convolution layers.

    Stacks ``len(n_filters)`` convolutions. Each layer is, in order:
    ``conv_op`` -> optional batch normalization -> optional non-linearity
    (never on the last layer) -> optional dropout -> optional upsampling.

    Args:
        in_signal: Tensor fed to the first convolution.
        n_filters: Number of filters of each convolution; its length defines
            the number of layers (at least 1).
        filter_sizes: Per-layer filter sizes, passed through
            ``replicate_parameter_for_all_layers``.
        strides: Per-layer strides, passed through
            ``replicate_parameter_for_all_layers``.
        padding: Forwarded to ``conv_op``.
        b_norm: If true, batch-normalize every layer except the last.
        non_linearity: Callable applied to every layer except the last, or
            None to skip.
        conv_op: Convolution constructor used for every layer.
        regularizer: Forwarded to ``conv_op``.
        weight_decay: Forwarded to ``conv_op``.
        dropout_prob: Per-layer drop probabilities (or None), passed through
            ``replicate_parameter_for_all_layers``. Layer ``i`` gets
            ``dropout(layer, 1.0 - dropout_prob[i])`` when the value is > 0.
        upsample_sizes: Sequence indexed per layer (it is *not* replicated);
            a non-None entry tiles the layer output that many times along
            axis 1. None disables upsampling.
        b_norm_finish: If true, batch-normalize the last layer as well.
        scope: Base scope, combined with each layer name through
            ``expand_scope_by_name``.
        reuse: Forwarded to the convolution and batch-normalization layers.
        verbose: If true, print diagnostics while building the graph.

    Returns:
        The output tensor of the last layer.

    Raises:
        ValueError: If ``n_filters`` is empty.
    """

    if verbose:
        print("Building Decoder")

    n_layers = len(n_filters)
    filter_sizes = replicate_parameter_for_all_layers(filter_sizes, n_layers)
    strides = replicate_parameter_for_all_layers(strides, n_layers)
    dropout_prob = replicate_parameter_for_all_layers(dropout_prob, n_layers)

    # Without any layer there would be nothing to return; fail with a clear
    # message instead of an UnboundLocalError at the return statement.
    if n_layers < 1:
        raise ValueError("At least 1 layer is expected.")

    last = n_layers - 1
    layer = in_signal
    for i in range(n_layers):
        name = "decoder_conv_layer_" + str(i)
        scope_i = expand_scope_by_name(scope, name)

        layer = conv_op(
            layer,
            nb_filter=n_filters[i],
            filter_size=filter_sizes[i],
            strides=strides[i],
            padding=padding,
            regularizer=regularizer,
            weight_decay=weight_decay,
            name=name,
            reuse=reuse,
            scope=scope_i,
        )

        if verbose:
            print(
                (
                    name,
                    "conv params = ",
                    np.prod(layer.W.get_shape().as_list())
                    + np.prod(layer.b.get_shape().as_list()),
                )
            )

        # Inner layers follow b_norm; the last layer follows b_norm_finish.
        if (b_norm and i < last) or (i == last and b_norm_finish):
            name += "_bnorm"
            scope_i = expand_scope_by_name(scope, name)
            layer = batch_normalization(layer, name=name, reuse=reuse, scope=scope_i)
            if verbose:
                print(
                    (
                        "bnorm params = ",
                        np.prod(layer.beta.get_shape().as_list())
                        + np.prod(layer.gamma.get_shape().as_list()),
                    )
                )

        # Last layer doesn't have a non-linearity.
        if non_linearity is not None and i < last:
            layer = non_linearity(layer)

        if dropout_prob is not None and dropout_prob[i] > 0:
            # dropout_prob holds drop probabilities; dropout receives the
            # complement.
            layer = dropout(layer, 1.0 - dropout_prob[i])

        if upsample_sizes is not None and upsample_sizes[i] is not None:
            # Upsample by repeating the whole tensor along axis 1.
            layer = tf.tile(layer, multiples=[1, upsample_sizes[i], 1])

        if verbose:
            print(layer)
            print(("output size:", np.prod(layer.get_shape().as_list()[1:]), "\n"))

    return layer
